package org.zjvis.datascience.service;

import com.alibaba.fastjson.JSONObject;
import javax.annotation.PostConstruct;
import org.apache.hadoop.yarn.webapp.hamlet2.Hamlet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.zjvis.datascience.common.util.db.JDBCUtil;
import org.zjvis.datascience.service.dag.TaskRunnerResult;
import org.zjvis.datascience.service.dataprovider.GPDataProvider;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.regex.Pattern;

/**
 * @description UDF service for user-defined GP database functions: builds the DDL for
 *     PL/Python UDFs in the {@code "pipeline"} schema and invokes such functions over a
 *     GP connection.
 * @date 2021-11-29
 */
@Service
public class UDFService {
    private final static Logger logger = LoggerFactory.getLogger(UDFService.class);

    /**
     * Matches a bare {@code nan} token (Python's repr of NaN) without touching words that merely
     * contain the letters "nan", such as "Hunan", "Henan" or "banana".
     */
    private static final Pattern NAN_TOKEN = Pattern.compile("\\bnan\\b");

    /**
     * Status reported when the query fails or its output has no readable integer "status" field.
     * NOTE(review): confirm callers do not interpret 0 as success.
     */
    private static final int DEFAULT_STATUS = 0;

    @Autowired
    private GPDataProvider gpDataProvider;

    @Autowired
    private OperatorTemplateService operatorTemplateService;

    // Startup registration of the reverse-geocode UDF is currently disabled; uncommenting this
    // would pass reverseGeocodeFcn() to OperatorTemplateService.registerUDF on bean init.
//    @PostConstruct
//    public void init() throws Exception {
//        operatorTemplateService.registerUDF(reverseGeocodeFcn());
//    }

    /**
     * Runs {@code sql} on the GP connection and wraps the first column of the first row.
     *
     * <p>Best-effort: failures are logged and reported through the result rather than thrown.
     *
     * @param sql query to execute; only the first column of the first row is read
     * @return a result whose output is that value ({@code ""} if the query failed or returned no
     *     rows, {@code null} if the value was SQL NULL) and whose status is the integer
     *     {@code "status"} field of the output parsed as JSON, or {@link #DEFAULT_STATUS}
     */
    private TaskRunnerResult getRetFromDB(String sql) {
        String output = "";
        Connection conn = null;
        Statement st = null;
        ResultSet rs = null;
        try {
            // NOTE(review): 1L is a hard-coded connection/datasource id — confirm its meaning.
            conn = gpDataProvider.getConn(1L);
            st = conn.createStatement();
            rs = st.executeQuery(sql);
            if (rs.next()) {
                output = rs.getString(1);
            }
        } catch (Exception e) {
            // Keep the best-effort contract (empty output, default status) but leave a trace.
            logger.error("UDF query failed: {}", sql, e);
        } finally {
            JDBCUtil.close(conn, st, rs);
        }
        return new TaskRunnerResult(parseStatus(output), output);
    }

    /**
     * Extracts the integer {@code "status"} field from a JSON object string.
     *
     * @param output raw UDF output; may be {@code null}, empty, or not JSON at all
     * @return the status value, or {@link #DEFAULT_STATUS} when it cannot be read
     */
    private static int parseStatus(String output) {
        if (output == null || output.isEmpty()) {
            return DEFAULT_STATUS;
        }
        try {
            JSONObject json = JSONObject.parseObject(output);
            Integer status = json == null ? null : json.getInteger("status");
            return status == null ? DEFAULT_STATUS : status;
        } catch (RuntimeException e) {
            // Plain-text outputs (e.g. the reverse-geocode UDF's city/country string) are not JSON.
            logger.debug("UDF output has no readable status field: {}", output);
            return DEFAULT_STATUS;
        }
    }

    /**
     * Calls {@code "pipeline".<fcnName>(<args>)} and converts its Python-style output to JSON-like
     * text: single quotes become double quotes and bare {@code nan} tokens become {@code null}.
     *
     * <p>NOTE(review): apostrophes inside string values and Python literals such as
     * {@code None}/{@code True}/{@code inf} are not converted.
     *
     * <p>SECURITY: {@code fcnName} and {@code args} are spliced into the SQL verbatim; callers
     * must only pass trusted, already-quoted identifiers and literals.
     *
     * @param fcnName function name inside the "pipeline" schema, quoted as needed
     * @param args raw SQL argument list placed between the parentheses
     * @return the converted output, {@code ""} if the query failed or returned no rows, or
     *     {@code null} if the output was null (e.g. the function returned SQL NULL)
     */
    public String runPyFcn(String fcnName, String args) {
        String sql = String.format("select * from \"pipeline\".%s(%s)", fcnName, args);
        String ret = getRetFromDB(sql).getOutput();
        if (ret == null) {
            return null;
        }
        ret = ret.replace('\'', '"');
        return NAN_TOKEN.matcher(ret).replaceAll("null");
    }

    /**
     * Builds the {@code CREATE OR REPLACE FUNCTION} DDL for
     * {@code "pipeline"."sys_func_reverse_geocode"(coordinate text, semantic varchar)}, a
     * plpythonu UDF returning text.
     *
     * <p>The UDF splits {@code coordinate} on ',' and passes the two parts to
     * {@code reverse_geocode.get} in swapped order (so the input is presumably "lon,lat"). For
     * {@code semantic == 'province'} it looks up the province of the resolved city in
     * {@code dataset._city_mapper_}. Otherwise it returns the {@code semantic} field of the
     * reverse_geocode result, mapping the countries Taiwan, Hong Kong and Macao to "China". Any
     * error yields NULL.
     *
     * <p>The city lookup uses a prepared plan with a {@code $1} parameter, so city names
     * containing quotes (e.g. "Xi'an") neither break nor inject into the query.
     *
     * @return the DDL script
     */
    public static String reverseGeocodeFcn() {
        String fcnName = "\"sys_func_reverse_geocode\"";
        String args = "\"coordinate\" text, \"semantic\" varchar";
        String body =
            "import reverse_geocode\n"
                + "if __name__ == \"__main__\":\n"
                + "    try:\n"
                + "        coord_split = coordinate.split(',')\n"
                + "        coord = (float(coord_split[1]), float(coord_split[0]))\n"
                + "\n"
                + "        if semantic == 'province':\n"
                + "            ret = reverse_geocode.get(coord)['city']\n"
                + "            plan = plpy.prepare(\"select province from dataset._city_mapper_ where en = $1\", [\"text\"])\n"
                + "            rs = plpy.execute(plan, [ret])\n"
                + "            ret = rs[0]['province']\n"
                + "        else:\n"
                + "            ret = reverse_geocode.get(coord)[semantic]\n"
                + "            if semantic == 'country' and (ret == 'Taiwan' or ret == 'Hong Kong' or ret == 'Macao'):\n"
                + "                ret = 'China'\n"
                + "    except:\n"
                + "        ret = None\n"
                + "    return ret";
        String script = String.format("CREATE OR REPLACE FUNCTION \"pipeline\".%s(%s) \n" +
            "    RETURNS \"pg_catalog\".\"text\" AS $BODY$\n" +
            "%s\n" +
            "$BODY$\n" +
            "  LANGUAGE plpythonu VOLATILE\n" +
            "  COST 100\n", fcnName, args, body);
        return script;
    }
}
